Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Apollo optimizer (https://arxiv.org/pdf/2412.05270) #196

Open
wants to merge 13 commits into
base: master
Choose a base branch
from

Conversation

murrellb
Copy link

This adds a draft of the low-rank Apollo optimizer that was preprinted yesterday (https://arxiv.org/pdf/2412.05270). Looks like it has some very nice properties, especially with the low memory footprint.

This works on the one case I've tested, but should be considered a WIP/draft until the codebase from the preprint is available, especially since there was a little guesswork from the manuscript alone. But I figured I'd open a PR in case anyone else wants to play with it, and if there is interest in merging I'll spend a bit more effort to check that it matches the "official" implementation when that arrives.

PR Checklist

  • Tests are added
  • Documentation, if applicable

Copy link
Member

@mcabbott mcabbott left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good.

If you edit this line https://github.com/FluxML/Optimisers.jl/blob/master/test/rules.jl#L11 then it CI will check that it does converge on some sample problems.

src/rules.jl Outdated
Comment on lines 609 to 613
struct Apollo{T1} <: AbstractRule
opt::T1
r::Int #Subspace rank
u::Int #Subspace update frequency (T in paper)
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this must be AdamW, then you could do this, which supplies defaults:

Suggested change
struct Apollo{T1} <: AbstractRule
opt::T1
r::Int #Subspace rank
u::Int #Subspace update frequency (T in paper)
end
@def struct Apollo{T1} <: AbstractRule
opt = AdamW()
r = 10 # Subspace rank
u = 10 # Subspace update frequency (T in paper)
end

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wanted the user to be able to pass in either an Int or a function for rank (the latter where they can scale the rank based on the dim), so I've written some custom constructors with defaults instead of this approach.

src/rules.jl Outdated
init(o::Apollo, x::AbstractArray{T,1}) where T = init(o.opt, x)
apply!(o::Apollo, state, x::AbstractArray{T,1}, dx) where T = apply!(o.opt, state, x, dx)

function init(o::Apollo, x::AbstractArray{T,2}) where T
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For arrays of >2D (e.g. weight of Conv), should there be methods to reshape to matrix & reshape back?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that is on my list. There is also this pesky assertion about the dimension ordering that means that some matrices will have to be transposed:

image

but I'm not sure a lazy transpose as W comes in and goes out will be optimal - might have to write a new path for those. I suspect we (as in "humanity") don't know if this actually helps, so I'm go make sure this can be controlled by a flag.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both of these are included now.

@murrellb
Copy link
Author

I've now added an attempt at a "gradient norm growth limiter" because this paper used this in conjunction with Apollo. I think they apply this over the whole model, but here it will apply per tensor. This should still do the trick though, and it seems like it might be more generally useful for controlling loss spikes.

Apollo seems to error on some "gradient type" tests, where something is passed in that I didn't expect and don't understand. I think it doesn't even like calling size on these inputs? Haven't looked too closely, but maybe this is a familiar issue to you folks?

Other than that, I actually think that maybe we don't need to wait for the ref implementation to merge? Would Optimisers.jl consider an "experimental" tier of optmizers (eg. "ExperimentalApollo"), where they might not match the reference implementation and where their behavior may change in the future? There are many optimizers coming out, and this might be a decent strategy to stay cutting-edge for ones that are not yet battle-tested?

@murrellb murrellb changed the title Add Apollo optimizer (https://arxiv.org/pdf/2412.05270) - WIP Add Apollo optimizer (https://arxiv.org/pdf/2412.05270) Dec 12, 2024
@murrellb
Copy link
Author

murrellb commented Dec 12, 2024

The authors of the method haven't yet posted code, but they now link to this implementation on their github: https://github.com/zhuhanqing/APOLLO/tree/main

I think we could consider merging this? @mcabbott ? Edit: I've just realized I need to figure out how to make the random projection matrix update on-device.

@murrellb murrellb marked this pull request as draft December 13, 2024 06:45
@murrellb
Copy link
Author

This is now working on GPU. In combination with FluxML/Zygote.jl#1541 I can now train a 7 billion parameter transformer on a single 48gb A6000 GPU:

image

@murrellb murrellb marked this pull request as ready for review December 13, 2024 16:18
src/rules.jl Outdated
"""
GradNormGrowthLimiter(γ = 1.1; m = 1e-3, ϵ = 1e-8, throw = true, paramscale_min = true)

Gradient norm growth limiter from Chen et al. (https://arxiv.org/pdf/2410.01623) and used with Apollo in Zhu et al. (https://arxiv.org/pdf/2412.05270).
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Gradient norm growth limiter from Chen et al. (https://arxiv.org/pdf/2410.01623) and used with Apollo in Zhu et al. (https://arxiv.org/pdf/2412.05270).
Gradient norm growth limiter from [Chen et al.](https://arxiv.org/abs/2410.01623) and used with Apollo in [Zhu et al.](https://arxiv.org/abs/2412.05270).

src/rules.jl Outdated
Gradient norm growth limiter from Chen et al. (https://arxiv.org/pdf/2410.01623) and used with Apollo in Zhu et al. (https://arxiv.org/pdf/2412.05270).
With Optimisers.jl this will apply per-tensor, which may not be the same as the implementations in these papers. It still seems to help, but the ideal settings may vary.
This also introduces `m` a hard minimum on the gradient norm, and never rescales grads below this, preventing a tensor from getting "trapped" near zero.
This can be a fixed min, or scaled by the number of parameters in the tensor (with `paramscale_min = true`).
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

explain the role of gamma?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in the new version.

src/rules.jl Outdated
@@ -599,6 +599,136 @@ function apply!(o::AdaBelief, state, x::AbstractArray{T}, dx) where T
return (mt, st, βt .* β), dx′
end


"""
GradNormGrowthLimiter(γ = 1.1; m = 1e-3, ϵ = 1e-8, throw = true, paramscale_min = true)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should the default value for m correspond to the original paper (i.e. m=0 i suppose)?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

m=0 makes sense when this is applied to the entire model, but could be fatal when applied tensor-wise. I think it is better to have non-footgun defaults, and make it clearer that this isn't a faithful reproduction?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've kept a non-zero default, but I've tweaked the docs to clarify that this method isn't quite the same as in those papers. (I also switched the "scaling m by the number of parameters" to using sqrt).

@murrellb
Copy link
Author

Thanks @CarloLucibello - I just woke up and noticed some issues with this PR, so I was about to convert it to draft! I need to sort out adjust for Apollo. I'll work through your GradNormGrowthLimiter comments too.

@murrellb murrellb marked this pull request as draft December 14, 2024 12:31
@murrellb murrellb marked this pull request as ready for review December 14, 2024 20:05
@zhuhanqing
Copy link

Hi all, I am one of the main authors of the APOLLO paper, Hanqing Zhu.
Thanks so much for bringing our APOLLO to FluxML. I am so sorry our code releases are still blocked by the internal review process. If you have any questions or want to know implementation details, please let us know, and we are more than happy to assist you with this integration and implementation process!

src/rules.jl Outdated
Comment on lines 606 to 609
Gradient norm growth limiter. Inspired by [Chen et al.](https://arxiv.org/abs/2410.01623) and used with Apollo in [Zhu et al.](https://arxiv.org/abs/2412.05270), but
with Optimisers.jl this will apply per-tensor instead of per-model, and as a result the defaults are different. `γ` controls the maximum that the gradient norm can grow
from one step to the next. This implementation also introduces `m` a hard minimum on the gradient norm threshold, and never rescales grads below this, preventing a tensor
from getting "trapped" near zero. This can be a fixed min, or scaled by the square root of the number of parameters in the tensor (with `paramscale_min = true`).
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this explain what it does do, mathematically, before explaining that it's different to some paper?

γ controls the maximum that the gradient norm can grow from one step to the next.

I don't know what this means without reading the code. Can you write like if norm(dx, 2) > γ * norm(dx_prev, 2) to explain the condition, and explain exactly what happens if this is violated?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

src/rules.jl Outdated


Apollo() = Apollo(AdamW(0.001), 0.001, dim -> ceil(Int, sqrt(dim)), 100, true)
Apollo(η::Real, rank::Int; u = 100, sort_dims = true) = Apollo(AdamW(η), η, dim -> max(dim, rank), u, sort_dims)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you sure you want max?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for catching this.

src/rules.jl Outdated
Comment on lines 663 to 664
opt::T1
eta::T2
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why store opt and eta?

Copy link
Author

@murrellb murrellb Dec 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I originally just stored opt, but then getting adjust working seemed tricky (likely a skill issue on my part though). Options were to include all the other AdamW params directly in this struct, or have an AdamW that only applies to the low-rank moments (which doesn't use eta, so its eta is redundant), and a separate eta that gets tweaked by adjust. The latter seemed better because then you can just wrap an existing AdamW in this.

Edit: another reason for storing an AdamW is that the AdamW is used instead of Apollo on regular arrays. But I just realized that now "adjust" won't work for regular arrays. I'll try figuring this out...

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Storing an AdamW seems fine, surely we can make adjust just work through onto the inner struct.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've made adjust work on the inner Adam now, so have dropped the additional eta.

src/rules.jl Outdated
Comment on lines 666 to 667
u::T4 #Subspace update frequency (T in paper)
sort_dims::T5 #Whether to swap the dims of x and dx when the second dim is smaller than the first
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These have fixed types, right?

Suggested change
u::T4 #Subspace update frequency (T in paper)
sort_dims::T5 #Whether to swap the dims of x and dx when the second dim is smaller than the first
u::Int # Subspace update frequency (T in paper)
sort_dims::Bool # Whether to swap the dims of x and dx when the second dim is smaller than the first

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup.

Comment on lines +699 to +700
dx = Broadcast.materialize(dx) #This is to stop the "gradient type" @lazy test from failing due to reshape.
dx = reshape(dx, size(x,1), nonfirstdims(x))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you need to materialize in matrix case?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For everything except the whatever comes in during the "gradient type" test you don't need materialize. I wasn't 100% sure exactly what is coming in during those tests, so wasn't sure how to separate them from regular matrix/tensors. What do you suggest here?

src/rules.jl Outdated
s = sqrt.(sum(abs2.(Rhat), dims=1))[:] ./ (sqrt.(sum(abs2.(R), dims=1))[:] .+ ϵ)
dx′′ = η * (dx .* reshape(s, 1, :)) + λ * x
if swapped
dx′′ = dx′′'
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
dx′′ = dx′′'
dx′′ = transpose(dx′′)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, this sort of branching introduces type instability. IDK if we care but perhaps worth some thought. Maybe there's a nicer way to just store everything transposed?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe an optimization we can figure out later if it becomes an issue?

src/rules.jl Outdated
Comment on lines 723 to 725
Rhat = @. mt / (1 - βt[1]) / (sqrt(vt / (1 - βt[2])) + ϵ)
s = sqrt.(sum(abs2.(Rhat), dims=1))[:] ./ (sqrt.(sum(abs2.(R), dims=1))[:] .+ ϵ)
dx′′ = η * (dx .* reshape(s, 1, :)) + λ * x
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These lines allocate a lot.

Rhat isn't used?

For the rest maybe it can be something like

sum1R2 = sum(abs2, R; dims=1)  # it's already the right shape, no need for [:] & reshape(s, 1, :)?
s = @. sqrt(sum1R2) / sqrt(Rhat + ϵ)
dx′′ = @lazy η * (dx * s) + λ * x  # one fused broadcast

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I got something like this working, but the @lazy breaks things, so omitted for now.

src/rules.jl Outdated
end


Apollo() = Apollo(AdamW(0.001), 0.001, dim -> ceil(Int, sqrt(dim)), 100, true)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't this method just be created by giving a default to eta in the next one?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is fixed via a different route.

src/rules.jl Outdated
paramscale_min::Bool
end

GradNormGrowthLimiter(γ = 1.1; m = 1e-3, ϵ = 1e-8, throw = true, paramscale_min = true) = GradNormGrowthLimiter(γ, m, ϵ, throw, paramscale_min)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't have greek-letter keyword options, nor field names -- the API should never ask the user to type these. They are used only in documentation / as local variables. Probably the first 3 should be positional.

Bikeshedding names bit, to avoid overly long things, the constructor could be:

Suggested change
GradNormGrowthLimiter= 1.1; m = 1e-3, ϵ = 1e-8, throw = true, paramscale_min = true) = GradNormGrowthLimiter(γ, m, ϵ, throw, paramscale_min)
NormGrowLimit= 1.1, m = 1e-3, ε = 1e-8; throw = true, scale = true) = NormGrowLimit(γ, m, ε, throw, scale)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I went with NormGrowthCap here.

src/rules.jl Outdated
Comment on lines 612 to 614
γ::Float64
m::Float64 #Min grad norm, to stop a tensor getting stuck near zero
ϵ::Float64
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't allow unicode field names, suggest:

Suggested change
γ::Float64
m::Float64 #Min grad norm, to stop a tensor getting stuck near zero
ϵ::Float64
gamma::Float64
mu::Float64 # Min grad norm, to stop a tensor getting stuck near zero
epsilon::Float64

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, but changed the variable names to avoid eg. gamma.

@murrellb murrellb marked this pull request as draft December 15, 2024 22:00
@murrellb murrellb marked this pull request as ready for review December 16, 2024 02:42
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants